import numpy as np

from 工具.数学函数 import 计算双弯曲
from 工具.特征 import 训练集预处理


class 多层感知机:
    def __init__(self, 数据, 标签, 行列数, 是否归一化=False):
        self.数据 = 训练集预处理(数据, 是否归一化=是否归一化)
        self.标签 = 标签
        self.行列数 = 行列数  # 784 25 10
        self.是否归一化 = 是否归一化
        # θ,这里为权重参数
        self.西塔值矩阵 = self.初始化西塔值矩阵(行列数)

    def 预测(self, 测试集):
        self.数据 = 训练集预处理(测试集, 是否归一化=self.是否归一化)
        样本数 = 测试集.shape[0]
        预测结果 = self.前向传播()

        return np.argmax(预测结果, axis=1).reshape((样本数, 1))

    def 训练(self, 最大迭代次数=1000, 阿尔法=0.1):
        self.西塔值矩阵 = self.展平西塔值矩阵(self.西塔值矩阵)
        历史损失值矩阵 = self.梯度下降(最大迭代次数, 阿尔法)
        return 历史损失值矩阵

    def 梯度下降(self, 最大迭代次数, 阿尔法):
        历史损失值矩阵 = []
        for _ in range(最大迭代次数):
            损失值矩阵 = self.计算损失值矩阵()
            历史损失值矩阵.append(损失值矩阵)
            展平的西塔值梯度矩阵 = self.计算梯度()
            self.西塔值矩阵 = self.展平西塔值矩阵(self.西塔值矩阵)
            self.西塔值矩阵 = self.西塔值矩阵 - 阿尔法 * 展平的西塔值梯度矩阵

        return 历史损失值矩阵

    def 计算梯度(self):
        self.西塔值矩阵 = self.合拢西塔值矩阵(self.西塔值矩阵)
        西塔值梯度矩阵 = self.反向传播()
        展平的西塔值梯度矩阵 = self.展平西塔值矩阵(西塔值梯度矩阵)
        return 展平的西塔值梯度矩阵

    def 反向传播(self):
        列数 = len(self.行列数)
        (样本数, 特征数) = self.数据.shape
        标签数 = self.行列数[-1]
        德尔塔字典 = {}
        for 列索引 in range(列数 - 1):
            输入层数量 = self.行列数[列索引]
            输出层数量 = self.行列数[列索引 + 1]
            德尔塔字典[列索引] = np.zeros((输出层数量, 输入层数量 + 1))
        for 样本索引 in range(样本数):
            输入层 = {}
            # 激活层相当于隐藏层，这里的“层”是矩阵
            激活层 = {}
            当前激活层 = self.数据[样本索引, :].reshape((特征数, 1))
            激活层[0] = 当前激活层

            for 列索引 in range(列数 - 1):
                当前列西塔值矩阵 = self.西塔值矩阵[列索引]
                当前输入层 = np.dot(当前列西塔值矩阵, 当前激活层)
                当前激活层 = np.vstack((np.array([[1]]), 计算双弯曲(当前输入层)))
                输入层[列索引 + 1] = 当前输入层
                激活层[列索引 + 1] = 当前激活层
            激活的输出层 = 当前激活层[1:, :]

            当前德尔塔字典 = {}
            标签矩阵 = np.zeros((标签数, 1))
            标签矩阵[self.标签[样本索引][0]] = 1
            当前德尔塔字典[列数 - 1] = 激活的输出层 - 标签矩阵

            for 列索引 in range(列数 - 2, 0, -1):
                当前层西塔值矩阵 = self.西塔值矩阵[列索引]
                下一个德尔塔 = 当前德尔塔字典[列索引 + 1]
                当前输入层 = 输入层[列索引]
                当前输入层 = np.vstack((np.array((1)), 当前输入层))
                当前德尔塔字典[列索引] = np.dot(当前层西塔值矩阵.T, 下一个德尔塔) * 计算双弯曲(当前输入层)
                当前德尔塔字典[列索引] = 当前德尔塔字典[列索引][1:, :]

            for 列索引 in range(列数 - 1):
                当前德尔塔 = np.dot(当前德尔塔字典[列索引 + 1], 激活层[列索引].T)
                德尔塔字典[列索引] = 德尔塔字典[列索引] + 当前德尔塔

        for 列索引 in range(列数 - 1):
            德尔塔字典[列索引] = 德尔塔字典[列索引] * (1 / 样本数)

        return 德尔塔字典

    def 计算损失值矩阵(self):
        列数 = len(self.行列数)
        样本数 = self.数据.shape[0]
        标签数 = self.行列数[-1]

        预测值矩阵 = self.前向传播()
        标签矩阵 = np.zeros((样本数, 标签数))
        for 样本索引 in range(样本数):
            标签矩阵[样本索引][self.标签[样本索引][0]] = 1

        是一的 = np.sum(np.log(预测值矩阵[标签矩阵 == 1]))
        不是一的 = np.sum(np.log(1 - 预测值矩阵[标签矩阵 == 0]))

        损失值矩阵 = (-1 / 样本数) * (是一的 + 不是一的)

        return 损失值矩阵

    def 前向传播(self):
        列数 = len(self.行列数)
        样本数 = self.数据.shape[0]
        # 激活是指使用激活函数就行计算
        激活的输入层矩阵 = self.数据

        西塔值矩阵 = self.合拢西塔值矩阵(self.西塔值矩阵)
        for 列索引 in range(列数 - 1):
            激活的输出层矩阵 = 计算双弯曲(np.dot(激活的输入层矩阵, 西塔值矩阵[列索引].T))
            激活的输出层矩阵 = np.hstack((np.ones((样本数, 1)), 激活的输出层矩阵))
            激活的输入层矩阵 = 激活的输出层矩阵

        return 激活的输入层矩阵[:, 1:]

    def 展平西塔值矩阵(self, 合拢的西塔值矩阵):
        西塔值列数 = len(合拢的西塔值矩阵)
        展平的西塔值矩阵 = np.array([])
        for 索引 in range(西塔值列数):
            展平的西塔值矩阵 = np.hstack((展平的西塔值矩阵, 合拢的西塔值矩阵[索引].flatten()))

        return 展平的西塔值矩阵

    def 合拢西塔值矩阵(self, 展平的西塔值矩阵):
        列数 = len(self.行列数)
        # 合拢的西塔值矩阵 = np.array([])
        合拢的西塔值矩阵 = {}
        当前轮移位置 = 0
        for 列索引 in range(列数 - 1):
            输入层数量 = self.行列数[列索引]
            输出层数量 = self.行列数[列索引 + 1]

            西塔值矩阵宽度 = 输入层数量 + 1
            西塔值矩阵高度 = 输出层数量
            西塔值数量 = 西塔值矩阵宽度 * 西塔值矩阵高度
            开始位置 = 当前轮移位置
            结尾位置 = 当前轮移位置 + 西塔值数量
            某一深度内的西塔值数量 = 展平的西塔值矩阵[开始位置:结尾位置]
            合拢的西塔值矩阵[列索引] = 某一深度内的西塔值数量.reshape((西塔值矩阵高度, 西塔值矩阵宽度))
            当前轮移位置 = 结尾位置

        return 合拢的西塔值矩阵

    @staticmethod
    def 初始化西塔值矩阵(行列数):
        # 可以改为列数
        列数 = len(行列数)
        西塔值矩阵 = {}
        for 列索引 in range(列数 - 1):
            # 输入层，输出层是指数据的输入，最终判断结果的输出，中间还有激活层或者说隐藏层，它们在矩阵中是一列列数据
            输入层数量 = 行列数[列索引]
            输出层数量 = 行列数[列索引 + 1]
            # 这里需要考虑到偏置项，记住一点偏置的个数跟输出的结果是一致的
            西塔值矩阵[列索引] = np.random.rand(输出层数量, 输入层数量 + 1) * 0.05

        return 西塔值矩阵
